# -*- coding: utf-8 -*-
"""
Created on Fri Apr 14 11:25:01 2023

@author: lv
"""

import jieba

# Path to the vocabulary file: one token per line; the (0-based) line
# index is the token id. The first 50 ids are reserved (see JiebaTokenizer).
dict_path = 'utils/dict.txt'

class JiebaTokenizer:
    """Tokenizer backed by jieba segmentation and a fixed vocabulary file.

    The vocabulary file (``dict_path``) lists one token per line; the
    0-based line index is the token id. The first 50 ids are reserved for
    special tokens: 0 is padding, 13 is the end-of-sequence marker, and
    1-4 map common whitespace characters.
    """

    def __init__(self):
        # token string -> id
        self.encode_map = {}
        # id -> token string
        self.decode_map = {}
        self.vocab_size = 0
        self.pad_token_id = 0
        # NOTE: id 13 (not 0) is the end-of-sequence marker; it decodes
        # to the empty string so it vanishes from decoded output.
        self.end_token_id = 13
        with open(dict_path, 'r', encoding='utf-8') as f:
            num = 0
            for i, line in enumerate(f):
                num += 1
                # The first 50 ids are reserved for special tokens.
                if i < 50:
                    continue
                line = line.strip()
                self.encode_map[line] = i
                self.decode_map[i] = line
            # End-of-sequence marker (decodes to '').
            self.encode_map['<ENDTOKEN>'] = self.end_token_id
            self.decode_map[self.end_token_id] = ''
            # Explicit whitespace tokens inside the reserved id range.
            self.encode_map[' '] = 1
            self.decode_map[1] = ' '
            self.encode_map['\t'] = 2
            self.decode_map[2] = '\t'
            self.encode_map['\n'] = 3
            self.decode_map[3] = '\n'
            self.encode_map['\r'] = 4
            self.decode_map[4] = '\r'
            # vocab_size = number of lines in the file = max id + 1.
            self.vocab_size = num

    def encode(self, text):
        """Encode *text* into a list of token ids.

        Words come from jieba full-mode segmentation; a word missing from
        the vocabulary falls back to per-character encoding. Characters
        that are also absent are skipped (previously this raised
        ``KeyError``), mirroring how ``decode`` skips unknown ids.
        """
        result = []
        for word in jieba.cut(text, cut_all=True):
            if word in self.encode_map:
                result.append(self.encode_map[word])
            else:
                # Per-character fallback; silently drop out-of-vocabulary
                # characters instead of crashing.
                for ch in word:
                    idx = self.encode_map.get(ch)
                    if idx is not None:
                        result.append(idx)
        return result

    def decode(self, indices):
        """Map token ids back to their strings; unknown ids are skipped.

        Returns a list of token strings (not a joined string), matching
        the original interface.
        """
        return [self.decode_map[i] for i in indices if i in self.decode_map]